{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"align-center\">\n",
    "<a href=\"https://oumi.ai/\"><img src=\"https://oumi.ai/docs/en/latest/_static/logo/header_logo.png\" height=\"200\"></a>\n",
    "\n",
    "[![Documentation](https://img.shields.io/badge/Documentation-latest-blue.svg)](https://oumi.ai/docs/en/latest/index.html)\n",
    "[![Discord](https://img.shields.io/discord/1286348126797430814?label=Discord)](https://discord.gg/oumi)\n",
    "[![GitHub Repo stars](https://img.shields.io/github/stars/oumi-ai/oumi)](https://github.com/oumi-ai/oumi)\n",
    "<a target=\"_blank\" href=\"https://colab.research.google.com/github/oumi-ai/oumi/blob/main/configs/projects/halloumi/halloumi_classifier_inference_notebook.ipynb\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
    "</div>\n",
    "\n",
    "👋 Welcome to Open Universal Machine Intelligence (Oumi)!\n",
    "\n",
    "🚀 Oumi is a fully open-source platform that streamlines the entire lifecycle of foundation models - from [data preparation](https://oumi.ai/docs/en/latest/resources/datasets/datasets.html) and [training](https://oumi.ai/docs/en/latest/user_guides/train/train.html) to [evaluation](https://oumi.ai/docs/en/latest/user_guides/evaluate/evaluate.html) and [deployment](https://oumi.ai/docs/en/latest/user_guides/launch/launch.html). Whether you're developing on a laptop, launching large scale experiments on a cluster, or deploying models in production, Oumi provides the tools and workflows you need.\n",
    "\n",
    "🤝 Make sure to join our [Discord community](https://discord.gg/oumi) to get help, share your experiences, and contribute to the project! If you are interested in joining one of the community's open-science efforts, check out our [open collaboration](https://oumi.ai/community) page.\n",
    "\n",
    "⭐ If you like Oumi and you would like to support it, please give it a star on [GitHub](https://github.com/oumi-ai/oumi)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# HallOumi Inference\n",
    "\n",
    "This notebook demonstrates how to run inference locally with the HallOumi 8B classifier.\n",
    "\n",
    "Note that this is the **classifier (non-generative)** flavor of the HallOumi family.\n",
    "This model lacks the ergonomics of the generative HallOumi (per-sentence classification, citations, and explanations), but is much more computationally efficient.\n",
    "It can be a good alternative when compute costs and latency are important for your use case.\n",
    "If you are interested in the generative version of HallOumi, please see [this notebook](https://github.com/oumi-ai/oumi/blob/main/configs/projects/halloumi/halloumi_inference_notebook.ipynb) instead.\n",
    "\n",
    "For more details on HallOumi, please read our [GitHub documentation](https://github.com/oumi-ai/oumi/blob/main/configs/projects/halloumi/README.md) and our [technical overview](https://oumi.ai/blog/posts/introducing-halloumi)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prerequisites\n",
    "\n",
    "Install the following packages before starting the inference walkthrough."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install transformers\n",
    "%pip install torch\n",
    "%pip install scipy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference Walkthrough"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dataset\n",
    "\n",
    "Let's start by defining a toy dataset in which each example consists of a `context` and a `claim`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "toy_dataset = [\n",
    "    {\n",
    "        \"context\": \"Today is a sunny day.\",\n",
    "        \"claim\": \"It is not raining today.\",\n",
    "    },\n",
    "    {\n",
    "        \"context\": \"James is a software engineer. He works at a tech company.\",\n",
    "        \"claim\": \"James loves his tech job.\",\n",
    "    },\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We then convert these examples to a list of prompts. To do so, we leverage the HallOumi classifier's prompt template (`PROMPT_TEMPLATE`) that is shown below, and replace the `context` and `claims` variables with the specific example's parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "PROMPT_TEMPLATE = \"<context>\\n{context}\\n</context>\\n\\n<claims>\\n{claim}\\n</claims>\"\n",
    "\n",
    "prompts = []\n",
    "for example in toy_dataset:\n",
    "    prompt = PROMPT_TEMPLATE.format(context=example[\"context\"], claim=example[\"claim\"])\n",
    "    prompts.append(prompt)"
   ]
  },
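  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the template concrete, the sketch below wraps the formatting step in a small helper and prints the prompt rendered for one example. The names `TEMPLATE` and `build_prompt` are hypothetical and introduced only for illustration; the template string is assumed to match `PROMPT_TEMPLATE` above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch: `TEMPLATE` and `build_prompt` are hypothetical names,\n",
    "# assumed to mirror the prompt template used earlier in this notebook.\n",
    "TEMPLATE = \"<context>\\n{context}\\n</context>\\n\\n<claims>\\n{claim}\\n</claims>\"\n",
    "\n",
    "\n",
    "def build_prompt(context: str, claim: str) -> str:\n",
    "    \"\"\"Fill the template's `context` and `claim` placeholders.\"\"\"\n",
    "    return TEMPLATE.format(context=context, claim=claim)\n",
    "\n",
    "\n",
    "example_prompt = build_prompt(\"Today is a sunny day.\", \"It is not raining today.\")\n",
    "print(example_prompt)"
   ]
  },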
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Inference\n",
    "\n",
    "#### Loading Model\n",
    "\n",
    "Next, we load the [HallOumi classifier model](https://huggingface.co/oumi-ai/HallOumi-8B-classifier) (`oumi-ai/HallOumi-8B-classifier`) from Hugging Face, together with its corresponding tokenizer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "690640801eaa4ab9b32deef5e01c6aa5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from transformers import AutoModelForSequenceClassification, AutoTokenizer\n",
    "\n",
    "model = AutoModelForSequenceClassification.from_pretrained(\n",
    "    \"oumi-ai/HallOumi-8B-classifier\"\n",
    ")\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"oumi-ai/HallOumi-8B-classifier\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Tokenizing\n",
    "\n",
    "Using the tokenizer instantiated above, we tokenize the prompts into a single padded batch of PyTorch tensors to query HallOumi."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "inputs = tokenizer(prompts, padding=True, truncation=True, return_tensors=\"pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Running Inference\n",
    "\n",
    "The following snippet runs inference with gradient tracking disabled and extracts the logits.\n",
    "We then apply softmax to convert the logits into class probabilities."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from scipy.special import softmax\n",
    "\n",
    "with torch.no_grad():\n",
    "    outputs = model(**inputs)\n",
    "\n",
    "logits = outputs.logits\n",
    "probabilities = softmax(logits.numpy(), axis=-1)"
   ]
  },
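  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what the softmax step does, the sketch below applies it by hand to made-up logits for two examples. The logit values are invented for illustration and are not HallOumi outputs, and `manual_softmax` is a hypothetical helper meant to mirror a row-wise softmax such as `scipy.special.softmax(..., axis=-1)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def manual_softmax(logits: np.ndarray) -> np.ndarray:\n",
    "    \"\"\"Row-wise softmax; subtracting the row max keeps exp() numerically stable.\"\"\"\n",
    "    shifted = logits - logits.max(axis=-1, keepdims=True)\n",
    "    exp = np.exp(shifted)\n",
    "    return exp / exp.sum(axis=-1, keepdims=True)\n",
    "\n",
    "\n",
    "# Made-up logits with shape (batch_size=2, num_classes=2); not real model outputs.\n",
    "example_logits = np.array([[2.0, 0.0], [-1.0, 3.0]])\n",
    "example_probs = manual_softmax(example_logits)\n",
    "\n",
    "# Each row sums to 1; column 1 plays the role of the hallucination class.\n",
    "print(example_probs.round(3))"
   ]
  },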
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Inspecting the results\n",
    "\n",
    "The last step is to iterate over the results. As shown below, HallOumi correctly identifies the first example as a non-hallucination (with probability 100% - 10% = 90%) and the second example as a hallucination (with probability 99%)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Context: Today is a sunny day.\n",
      "Claim: It is not raining today.\n",
      "Prediction: `Non-Hallucination`\n",
      "Hallucination probability: 10%\n",
      "\n",
      "Context: James is a software engineer. He works at a tech company.\n",
      "Claim: James loves his tech job.\n",
      "Prediction: `Hallucination`\n",
      "Hallucination probability: 99%\n",
      "\n"
     ]
    }
   ],
   "source": [
    "for probability, example in zip(probabilities, toy_dataset):\n",
    "    # Index 1 of each probability row is the hallucination (positive) class.\n",
    "    hallucination_prob = probability[1]\n",
    "    prediction = probability.argmax()\n",
    "    prediction_str = \"Hallucination\" if prediction == 1 else \"Non-Hallucination\"\n",
    "\n",
    "    # Print the results for inspection.\n",
    "    print(f\"Context: {example['context']}\")\n",
    "    print(f\"Claim: {example['claim']}\")\n",
    "    print(f\"Prediction: `{prediction_str}`\")\n",
    "    print(f\"Hallucination probability: {hallucination_prob:.0%}\\n\")"
   ]
  }
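  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loop above uses `argmax`, which for two classes amounts to flagging a claim when its hallucination probability exceeds 50%. If your use case calls for a stricter or more lenient cutoff, a custom threshold can be applied instead. The sketch below uses made-up probabilities and a hypothetical helper, `flag_hallucinations`; the threshold values are arbitrary examples, not recommendations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def flag_hallucinations(hallucination_probs, threshold=0.5):\n",
    "    \"\"\"Return True for each claim whose hallucination probability exceeds the threshold.\"\"\"\n",
    "    return [prob > threshold for prob in hallucination_probs]\n",
    "\n",
    "\n",
    "# Made-up hallucination probabilities, for illustration only.\n",
    "example_hallucination_probs = [0.10, 0.45, 0.99]\n",
    "\n",
    "# The default threshold of 0.5 matches the two-class argmax decision.\n",
    "default_flags = flag_hallucinations(example_hallucination_probs)\n",
    "# A lower threshold flags more claims, trading precision for recall.\n",
    "strict_flags = flag_hallucinations(example_hallucination_probs, threshold=0.3)\n",
    "\n",
    "print(default_flags)\n",
    "print(strict_flags)"
   ]
  }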
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "oumi",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
